library(tidyverse)
library(rstan)
# Default core count for rstan's parallel work (the mc.cores option), and
# enable rstan's auto_write option so compiled models are cached on disk.
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

# Compute the posterior mode (MAP) of the Poisson age/sex model with shrunk nu,
# initialised at the posterior medians of an earlier variational (VB) fit.
#
# Usage: Rscript <script> ... vb_input <path to VB result .rds> ...
#
# The .rds file must hold a list with elements `data` (the Stan data list) and
# `vb_fit` (a stanfit object). The result of rstan::optimizing() is saved to
# the input path with every "vb" replaced by "MAP".

args <- commandArgs(trailingOnly = TRUE)

if (!any(args == "vb_input")) {
  stop("Provide a vb estimation as input with the keyword vb_input")
} else {
  # Use the first occurrence of the keyword; the path is the next argument.
  vb_input_idx <- which(args == "vb_input")[1]
  if (vb_input_idx == length(args)) {
    stop("The keyword vb_input must be followed by the vb estimation path")
  }
  input_filename <- args[vb_input_idx + 1]

  # Read the file once; [[ ]] avoids `$` partial matching on list names.
  vb_input <- readRDS(input_filename)
  stan_data <- vb_input[["data"]]
  vb_fit <- vb_input[["vb_fit"]]
  if (is.null(stan_data) || is.null(vb_fit)) {
    stop("The vb input must be a list with elements 'data' and 'vb_fit'")
  }

  # Output path. If the input path contains no "vb", the replacement is a
  # no-op and saving would overwrite the vb input, so refuse that case.
  filename <- gsub("vb", "MAP", input_filename, fixed = TRUE)
  if (identical(filename, input_filename)) {
    stop(
      "Cannot derive an output file name: '", input_filename,
      "' does not contain \"vb\""
    )
  }

  # Posterior median of each VB model parameter (lp__ excluded), as a named
  # list of numeric vectors with one value per row of the parameter's summary.
  # NOTE: a parameter with more than one dimension would come out flattened;
  # the MAP model below only has scalar and 1-D array parameters.
  vb_par_names <- setdiff(vb_fit@model_pars, "lp__")
  median_vb_estimates <- vb_par_names %>%
    lapply(function(parname) {
      unname(summary(vb_fit, pars = parname)$summary[, "50%"])
    }) %>%
    setNames(vb_par_names)

  Poisson_model_age_sex_effect_nu_shrunk_MAP <- "
data {
  int<lower=0> N;         // number of data points
  int<lower=0> Nmunicipalities;
  int<lower=0> municipality[N];
  int<lower=0> Ymtas[N];
  int<lower=0> Nmtas[N];
  int<lower=0> covid[N]; //
  int<lower=0> Nage_class;         //
  real age_class_scaled[N];
  int<lower=0> sex[N]; //
  real<lower=0, upper=1> rho_as[N]; //
  real<lower=0> lambda;
}
parameters {
  real log_h0m_raw[Nmunicipalities];
  real<lower=0, upper=1> nu_m[Nmunicipalities];
  real mu_h;
  real beta_age;
  real beta_sex;
  real<lower=0> sigma_h;
}
  transformed parameters {
    real log_h0m[Nmunicipalities];

    for (m in 1:Nmunicipalities){
      log_h0m[m] = log_h0m_raw[m]*sigma_h + mu_h;
    }
  }
model {
  real lambda_eff[N];
  mu_h ~ normal(-5, 5);
  sigma_h ~ normal(1, 20);
  beta_age ~ normal(0, 2);
  beta_sex ~ normal(0, 2);

  nu_m ~ exponential(lambda);
  log_h0m_raw ~ std_normal();

  for (i in 1:N) {
      lambda_eff[i] = Nmtas[i]*(exp(log_h0m[municipality[i]] + beta_age*age_class_scaled[i] + beta_sex*sex[i]) + covid[i]*nu_m[municipality[i]]*rho_as[i]);
    }
//  print(lambda_eff);
//  print(Ymtas);

  Ymtas ~ poisson(lambda_eff);
}
"
  stan_file <- "Scripts/Poisson_model_age_sex_effect_nu_shrunk_MAP.stan"
  stan_code_lines <- strsplit(
    Poisson_model_age_sex_effect_nu_shrunk_MAP, "\n", fixed = TRUE
  )[[1]]

  # Rewrite the .stan file only when its contents differ. An unchanged model
  # then leaves the file untouched, which avoids invalidating the auto_write
  # cache (it may compare file modification times).
  dir.create(dirname(stan_file), showWarnings = FALSE, recursive = TRUE)
  if (!file.exists(stan_file) ||
      !identical(readLines(stan_file, warn = FALSE), stan_code_lines)) {
    cat(file = stan_file, Poisson_model_age_sex_effect_nu_shrunk_MAP)
  }

  # Print the compile time; inside this block it would otherwise be discarded.
  print(system.time(m <- stan_model(file = stan_file)))

  # lambda is data in the MAP model: fix it at its VB posterior median.
  # [[ ]] matches exactly, so a missing lambda is reported here instead of
  # partially matching another name or leaving the Stan data unset.
  if (is.null(median_vb_estimates[["lambda"]])) {
    stop("The vb fit has no parameter 'lambda' to fix in the MAP model")
  }
  stan_data$lambda <- median_vb_estimates[["lambda"]]

  # Penalised optimisation of the posterior (saved as the MAP estimate).
  # NOTE(review): init also carries entries that are not parameters of this
  # model (lambda, transformed parameters); rstan is expected to ignore them.
  MLE_fit <- optimizing(
    object = m, data = stan_data, verbose = TRUE,
    init = median_vb_estimates, iter = 100000
  )

  # A non-zero return code means the optimizer did not terminate normally.
  # Still save the result, but make the problem visible.
  if (!isTRUE(MLE_fit[["return_code"]] == 0)) {
    warning(
      "optimizing() returned code ", MLE_fit[["return_code"]],
      "; the MAP estimate may not have converged"
    )
  }

  print(paste("saving to", filename))

  saveRDS(MLE_fit, file = filename)
}
